import random
import torch

from PIL import Image, ImageFilter, ImageOps
import torchvision.transforms as transforms


class ToRGB:
    """Convert an input image to 3-channel RGB mode.

    Intended as the first step of a pipeline so that every later transform
    receives RGB input regardless of the source image's mode.
    """

    def __call__(self, x):
        # Delegate to the image's own ``convert`` (e.g. ``PIL.Image.Image.convert``).
        rgb_image = x.convert("RGB")
        return rgb_image


class Solarization:
    """Solarize a PIL image: invert every pixel value at or above a threshold.

    Args:
        threshold: Pixel values ``>= threshold`` are inverted. Defaults to
            128, the same value ``ImageOps.solarize`` uses by default, so
            ``Solarization()`` behaves exactly as before this parameter existed.
    """

    def __init__(self, threshold=128):
        self.threshold = threshold

    def __call__(self, x):
        return ImageOps.solarize(x, threshold=self.threshold)


class GaussianBlur:
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    On every call a blur radius is drawn uniformly from ``[sigma[0], sigma[1]]``
    (using Python's ``random`` module) and passed to
    ``PIL.ImageFilter.GaussianBlur``.

    Args:
        sigma: ``(low, high)`` range for the blur radius. Any indexable
            sequence is accepted; it is copied into a tuple so that later
            mutation of the caller's object cannot change this instance.
    """

    def __init__(self, sigma=(0.1, 2.0)):
        # Tuple default and tuple copy: the previous list default was a single
        # object shared by all instances, and a caller-provided list was kept
        # by reference.
        self.sigma = tuple(sigma)

    def __call__(self, x):
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))


def moco_transform(aug_plus):
    """Build the single-view MoCo training pipeline.

    Args:
        aug_plus: If truthy, use the stronger recipe: color jitter applied
            with p=0.8, then random grayscale, then Gaussian blur with p=0.5.
            Otherwise use the basic recipe: random grayscale followed by an
            always-applied color jitter.

    Returns:
        A ``transforms.Compose`` mapping an image to a normalized 224x224 tensor.
    """
    imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
    crop = transforms.RandomResizedCrop(224, scale=(0.2, 1.))

    if not aug_plus:
        return transforms.Compose([
            ToRGB(),
            crop,
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            imagenet_norm,
        ])

    jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
    return transforms.Compose([
        ToRGB(),
        crop,
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        imagenet_norm,
    ])


def byol_transform():
    """Build the two asymmetric view pipelines used for BYOL-style training.

    Both views share the crop, color-jitter, grayscale, flip and normalization
    settings. They differ only in:

    * view q: Gaussian blur always applied, no solarization;
    * view k: Gaussian blur with p=0.1, solarization with p=0.2 after the flip.

    Returns:
        ``[transform_q, transform_k]``, each a ``transforms.Compose``.
    """

    def build_view(blur_p, after_flip):
        steps = [
            ToRGB(),
            transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=blur_p),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomHorizontalFlip(),
        ]
        steps.extend(after_flip)
        steps.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        return transforms.Compose(steps)

    view_q = build_view(1.0, [])
    view_k = build_view(0.1, [transforms.RandomApply([Solarization()], p=0.2)])
    return [view_q, view_k]


def simclr_transform():
    """Build the SimCLR-style single-view augmentation pipeline.

    The color jitter uses (0.8, 0.8, 0.8, 0.2) for brightness, contrast,
    saturation and hue, and is applied with p=0.8. Gaussian blur is applied
    with p=0.5.

    Returns:
        A ``transforms.Compose`` mapping an image to a normalized 224x224 tensor.
    """
    strong_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
    blur = GaussianBlur([.1, 2.])
    pipeline = [
        ToRGB(),
        transforms.RandomResizedCrop(224, scale=(0.08, 1.)),
        transforms.RandomApply([strong_jitter], p=0.8),
        transforms.RandomApply([blur], p=0.5),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(pipeline)


def barlowtwins_transform():
    """Build the two asymmetric view pipelines used for Barlow Twins training.

    Both views use a bicubic random-resized crop to 224, a flip, color jitter
    (p=0.8), random grayscale and normalization. They differ only in:

    * first view: Gaussian blur always applied, no solarization;
    * second view: Gaussian blur with p=0.1, solarization with p=0.2.

    Returns:
        ``[transform_q, transform_k]``, each a ``transforms.Compose``.
    """

    def make_view(blur_prob, solarize):
        ops = [
            ToRGB(),
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.8,
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=blur_prob),
        ]
        if solarize:
            ops.append(transforms.RandomApply([Solarization()], p=0.2))
        ops.append(transforms.ToTensor())
        ops.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225]))
        return transforms.Compose(ops)

    first_view = make_view(1.0, solarize=False)
    second_view = make_view(0.1, solarize=True)
    return [first_view, second_view]


def swav_transform(num_crops=(2,), crop_sizes=(224,), min_scale_crops=(0.14,), max_scale_crops=(1.0,),
                   color_jitter_scale=1.0):
    """Build the SwAV multi-crop list of per-crop pipelines.

    Crop group ``i`` contributes ``num_crops[i]`` pipelines. Each one
    random-resized-crops to ``crop_sizes[i]`` with an area scale in
    ``[min_scale_crops[i], max_scale_crops[i]]``. It then applies a horizontal
    flip, a shared color distortion (jitter with p=0.8, then random grayscale),
    Gaussian blur with p=0.5, ``ToTensor`` and normalization.

    Args:
        num_crops: Number of crops per group.
        crop_sizes: Output crop size per group.
        min_scale_crops: Lower bound of the crop area scale per group.
        max_scale_crops: Upper bound of the crop area scale per group.
        color_jitter_scale: Multiplier applied to the (0.8, 0.8, 0.8, 0.2)
            brightness/contrast/saturation/hue jitter strengths.

    Returns:
        A flat list of ``sum(num_crops)`` ``transforms.Compose`` objects,
        ordered group by group.

    Raises:
        ValueError: If the four per-group sequences do not all have the same
            length.
    """
    lengths = (len(num_crops), len(crop_sizes), len(min_scale_crops), len(max_scale_crops))
    if len(set(lengths)) != 1:
        raise ValueError(
            "num_crops, crop_sizes, min_scale_crops and max_scale_crops must have "
            f"the same length, got lengths {lengths}"
        )

    color_jitter = transforms.ColorJitter(0.8 * color_jitter_scale, 0.8 * color_jitter_scale,
                                          0.8 * color_jitter_scale, 0.2 * color_jitter_scale)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])

    trans = []
    for n_crops, size, min_scale, max_scale in zip(num_crops, crop_sizes,
                                                   min_scale_crops, max_scale_crops):
        pipeline = transforms.Compose(
            [
                ToRGB(),
                transforms.RandomResizedCrop(size, scale=(min_scale, max_scale)),
                transforms.RandomHorizontalFlip(p=0.5),
                color_distort,
                transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )
        # The same pipeline object is repeated n_crops times. The random
        # transforms sample their parameters on every call, so each crop still
        # gets independent randomness.
        trans.extend([pipeline] * n_crops)
    return trans


def typical_imagenet_transform(train):
    """Build the standard supervised ImageNet pipeline.

    Args:
        train: If truthy, random-resized-crop to 224 and randomly flip.
            Otherwise resize to 256 and center-crop to 224.

    Returns:
        A ``transforms.Compose`` mapping an image to a normalized 224x224 tensor.
    """
    if train:
        geometry = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]
    else:
        geometry = [transforms.Resize(256), transforms.CenterCrop(224)]

    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose([ToRGB()] + geometry + to_normalized_tensor)


def resizecrop_transform(size=224):
    """Build a geometry-only pipeline: RGB conversion, random resized crop, flip.

    Args:
        size: Output side length of the random resized crop. The default of
            224 matches the previously hard-coded value.

    Returns:
        A ``transforms.Compose``. No ``ToTensor``/``Normalize`` is applied, so
        the output is still an image rather than a tensor.
    """
    transform = transforms.Compose(
        [
            ToRGB(),
            transforms.RandomResizedCrop(size, scale=(0.08, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
    )
    return transform


def sym_byol_color_transform():
    """Build a color-only pipeline that ends in a normalized tensor.

    Steps:
      1. color jitter (0.4, 0.4, 0.2, 0.1) with p=0.8;
      2. random grayscale (p=0.2);
      3. Gaussian blur (p=0.5);
      4. solarization (p=0.2);
      5. ``ToTensor`` and normalization.

    There is no ``ToRGB`` or crop step here. The input is assumed to already
    be an RGB image; in this file it is fed the output of
    ``resizecrop_transform``.
    """
    transform = transforms.Compose(
        [
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
            transforms.RandomApply([Solarization()], p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return transform


class MaskedAugmentation(object):
    """Produce two color-augmented views of one crop and block-mask the first.

    Each call does the following:
      1. Apply one random crop + flip (``position_transform``) to the input.
      2. Apply ``color_transform`` independently twice to that crop, giving
         two normalized tensors.
      3. Sample a ``feat_size x feat_size`` keep-mask. Each cell is 1 (keep)
         with probability ``1 - mask_ratio`` and 0 (masked) otherwise.
      4. Upsample the mask to pixel resolution in ``window_size x window_size``
         blocks and multiply it into the first view.

    The call returns ``[masked_view, unmasked_view, feat_mask]``.

    Args:
        mask_ratio: Probability that each window is zeroed in the first view.
        window_size: Side length, in pixels, of one mask window.
        image_size: Side length of the crop. It must be a multiple of
            ``window_size`` so the upsampled mask covers the view exactly.

    Raises:
        ValueError: If ``image_size`` is not a multiple of ``window_size``.
    """

    def __init__(self, mask_ratio, window_size, image_size=224):
        if image_size % window_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be a multiple of window_size ({window_size})"
            )
        # Crop to image_size (previously always 224) so the pixel mask built
        # from feat_size * window_size matches the view's spatial size.
        self.position_transform = resizecrop_transform(image_size)
        self.color_transform = sym_byol_color_transform()
        self.mask_ratio = mask_ratio
        self.window_size = window_size
        self.feat_size = image_size // window_size

    def random_mask(self):
        """Sample a random window mask.

        Returns:
            ``(pixel_mask, feat_mask)``:

            * ``feat_mask`` is an int32 tensor of shape
              ``(feat_size, feat_size)``, with 1 for kept windows and 0 for
              masked ones.
            * ``pixel_mask`` has shape
              ``(feat_size * window_size, feat_size * window_size)``. Each
              ``feat_mask`` entry is repeated over a
              ``window_size x window_size`` block.
        """
        feat_mask = (torch.rand(self.feat_size, self.feat_size) >= self.mask_ratio).int()
        ps = self.window_size
        pixel_mask = feat_mask.repeat_interleave(ps, dim=0).repeat_interleave(ps, dim=1)
        return pixel_mask, feat_mask

    def __call__(self, image):
        crop = self.position_transform(image)
        views = [self.color_transform(crop) for _ in range(2)]
        pixel_mask, feat_mask = self.random_mask()
        # Broadcast the (H, W) mask over the channel dimension of the first view.
        views[0] *= pixel_mask.unsqueeze(0)
        views.append(feat_mask)
        return views
